-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Adjust torch.compile() best practices #3336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
1. Add best practice to prefer `mod.compile` over `torch.compile(mod)`, which avoids `_orig_` naming problems. Repro steps: - opt_mod = torch.compile(mod) - train opt_mod - save checkpoint In another script, potentially on a machine that does NOT support `torch.compile`: load checkpoint. This fails with an error, because the checkpoint on `opt_mod` got its params renamed by `torch.compile`: ``` RuntimeError: Error(s) in loading state_dict for VQVAE: Missing key(s) in state_dict: "embedding.weight", "encoder.encoder.net.0.weight", "encoder.encoder.net.0.bias", ... Unexpected key(s) in state_dict: "_orig_mod.embedding.weight", "_orig_mod.encoder.encoder.net.0.weight", "_orig_mod.encoder.encoder.net.0.bias", ... ``` - Add best practice to use, or at least try, `fullgraph=True`. This doesn't always work, but we should encourage it. Note: I'm not a PyTorch expert, these are just based on footguns I've encountered over the past week.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3336
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 23b81b8 with merge base a5632da ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Hi @punkeel! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
Description
mod.compile
overtorch.compile(mod)
, which avoids_orig_
naming problems. Repro steps:In another script, potentially on a machine that does NOT support
torch.compile
: load checkpoint.This fails with an error, because the checkpoint on
opt_mod
got its params renamed bytorch.compile
:fullgraph=True
. This doesn't always work, but we should encourage it.Note: I'm not a PyTorch expert, these are just based on footguns I've encountered over the past week.
Checklist
cc @williamwen42 @msaroufim @anijain2305